#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#

import copy
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple, Union

from airbyte_cdk.models import AirbyteMessage, AirbyteStateBlob, AirbyteStateMessage, AirbyteStateType, AirbyteStreamState, StreamDescriptor
from airbyte_cdk.models import Type as MessageType
from pydantic import ConfigDict as V2ConfigDict


class HashableStreamDescriptor(StreamDescriptor):
    """
    Frozen variant of the protocol-generated StreamDescriptor, usable as a hash key (e.g. in a dict keyed by stream).

    StreamDescriptor is auto-generated from the Airbyte Protocol; this subclass only changes its model configuration so the
    fields cannot be mutated. It is public solely because unit tests outside this module depend on it.
    """

    # frozen=True makes the model immutable (and hashable); extra="allow" keeps accepting fields not declared on the model.
    model_config = V2ConfigDict(frozen=True, extra="allow")


class ConnectorStateManager:
    """
    ConnectorStateManager consolidates the various forms of a stream's incoming state message (STREAM / GLOBAL) under a common
    interface. It also provides methods to extract and update state.

    Per-stream state is held in ``self.per_stream_states``, a mapping from a HashableStreamDescriptor (name + namespace) to the
    stream's AirbyteStateBlob.
    """

    def __init__(self, state: Optional[List[AirbyteStateMessage]] = None):
        """
        :param state: The incoming list of state messages, or None when there is no prior state
        :raises ValueError: If the incoming GLOBAL state message carries a shared_state
        """
        shared_state, per_stream_states = self._extract_from_state_message(state)

        # We explicitly throw an error if we receive a GLOBAL state message that contains a shared_state because API sources are
        # designed to checkpoint state independently of one another. API sources should never be emitting a state message where
        # shared_state is populated. Rather than define how to handle shared_state without a clear use case, we're opting to throw an
        # error instead and if/when we find one, we will then implement processing of the shared_state value.
        if shared_state:
            raise ValueError(
                "Received a GLOBAL AirbyteStateMessage that contains a shared_state. This library only ever generates per-STREAM "
                "STATE messages so this was not generated by this connector. This must be an orchestrator or platform error. GLOBAL "
                "state messages with shared_state will not be processed correctly. "
            )
        self.per_stream_states = per_stream_states

    def get_stream_state(self, stream_name: str, namespace: Optional[str]) -> MutableMapping[str, Any]:
        """
        Retrieves the state of a given stream based on its descriptor (name + namespace).
        :param stream_name: Name of the stream being fetched
        :param namespace: Namespace of the stream being fetched
        :return: The per-stream state for a stream as a plain dict, or an empty dict if no state is stored for it
        """
        stream_state = self.per_stream_states.get(HashableStreamDescriptor(name=stream_name, namespace=namespace))
        if stream_state:
            return stream_state.dict()  # type: ignore # mypy thinks dict() returns any, but it returns a dict
        return {}

    def update_state_for_stream(self, stream_name: str, namespace: Optional[str], value: Mapping[str, Any]) -> None:
        """
        Overwrites the state blob of a specific stream based on the provided stream name and optional namespace
        :param stream_name: The name of the stream whose state is being updated
        :param namespace: The namespace of the stream if it exists
        :param value: A stream state mapping that is being updated for a stream
        """
        stream_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
        self.per_stream_states[stream_descriptor] = AirbyteStateBlob.parse_obj(value)

    def create_state_message(self, stream_name: str, namespace: Optional[str]) -> AirbyteMessage:
        """
        Generates an AirbyteMessage using the current per-stream state of a specified stream
        :param stream_name: The name of the stream for the message that is being created
        :param namespace: The namespace of the stream for the message that is being created
        :return: The Airbyte STREAM-type state message to be emitted by the connector during a sync; an empty
            AirbyteStateBlob is used when no state is stored for the stream
        """
        hashable_descriptor = HashableStreamDescriptor(name=stream_name, namespace=namespace)
        stream_state = self.per_stream_states.get(hashable_descriptor) or AirbyteStateBlob()

        return AirbyteMessage(
            type=MessageType.STATE,
            state=AirbyteStateMessage(
                type=AirbyteStateType.STREAM,
                stream=AirbyteStreamState(
                    stream_descriptor=StreamDescriptor(name=stream_name, namespace=namespace), stream_state=stream_state
                ),
            ),
        )

    @classmethod
    def _extract_from_state_message(
        cls,
        state: Optional[List[AirbyteStateMessage]],
    ) -> Tuple[Optional[AirbyteStateBlob], MutableMapping[HashableStreamDescriptor, Optional[AirbyteStateBlob]]]:
        """
        Takes an incoming list of state messages (either a single GLOBAL message or per-STREAM messages) and extracts state
        attributes according to type, which can then be assigned to the new state manager being instantiated
        :param state: The incoming state input
        :return: A tuple of shared state and per stream state assembled from the incoming state list. Shared state is only
            ever populated for GLOBAL input; for per-STREAM input it is None.
        """
        if state is None:
            return None, {}

        if cls._is_global_state(state):
            global_state = state[0].global_  # type: ignore # We verified state is a list in _is_global_state
            shared_state = copy.deepcopy(global_state.shared_state)
            streams = {
                HashableStreamDescriptor(
                    name=per_stream_state.stream_descriptor.name, namespace=per_stream_state.stream_descriptor.namespace
                ): per_stream_state.stream_state
                for per_stream_state in global_state.stream_states
            }
            return shared_state, streams

        # Only STREAM-typed messages whose `stream` payload is actually set contribute per-stream state. A plain `hasattr`
        # check is not enough: an attribute that exists but holds None would pass it and then raise AttributeError on
        # `.stream_descriptor` below, so messages without a payload are skipped instead.
        streams = {
            HashableStreamDescriptor(
                name=per_stream_state.stream.stream_descriptor.name,  # type: ignore # `stream` is checked for None in the filter
                namespace=per_stream_state.stream.stream_descriptor.namespace,  # type: ignore # see above
            ): per_stream_state.stream.stream_state  # type: ignore # see above
            for per_stream_state in state
            if per_stream_state.type == AirbyteStateType.STREAM and getattr(per_stream_state, "stream", None) is not None
        }
        return None, streams

    @staticmethod
    def _is_global_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
        """Return True when ``state`` is a list holding exactly one GLOBAL-typed AirbyteStateMessage."""
        return (
            isinstance(state, list)
            and len(state) == 1
            and isinstance(state[0], AirbyteStateMessage)
            and state[0].type == AirbyteStateType.GLOBAL
        )

    @staticmethod
    def _is_per_stream_state(state: Union[List[AirbyteStateMessage], MutableMapping[str, Any]]) -> bool:
        """Return True when ``state`` is a list (i.e. not a legacy mapping-shaped state)."""
        return isinstance(state, list)
